# Copyright 2020 The TensorStore Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensorstore.KvStore."""

import copy
import pathlib
import pickle
import tempfile
import threading
import time
from typing import Callable

import pytest
import tensorstore as ts


def test_instantiation() -> None:
  """Calling `ts.KvStore()` directly with no arguments raises TypeError."""
  kvstore_cls = ts.KvStore
  with pytest.raises(TypeError):
    kvstore_cls()


def test_spec_pickle() -> None:
  """KvStore.Spec reports its public name and survives a pickle round trip."""
  spec_cls = ts.KvStore.Spec
  original = spec_cls('memory://')
  assert spec_cls.__module__ == 'tensorstore'
  assert spec_cls.__qualname__ == 'KvStore.Spec'
  restored = pickle.loads(pickle.dumps(original))
  assert restored.to_json() == original.to_json()


def test_pickle() -> None:
  """Pickling an opened memory KvStore preserves its URL."""
  store = ts.KvStore.open('memory://').result()
  serialized = pickle.dumps(store)
  assert pickle.loads(serialized).url == 'memory://'


def test_copy() -> None:
  """Copies of a file-backed KvStore all observe writes made via the original.

  The shallow copy must be a distinct Python object. Both the shallow and the
  deep copy refer to the same directory, so a value written through the
  original is readable from every copy.
  """
  with tempfile.TemporaryDirectory() as tmp_dir:
    original = ts.KvStore.open({'driver': 'file', 'path': tmp_dir}).result()

    shallow = copy.copy(original)
    assert shallow is not original

    deep = copy.deepcopy(original)

    original['abc'] = b'def'
    for store in (original, shallow, deep):
      assert store['abc'] == b'def'


def test_keyrange() -> None:
  """KeyRange built from str bounds reprs them as bytes."""
  key_range = ts.KvStore.KeyRange('a', 'b')
  expected = "KvStore.KeyRange(b'a', b'b')"
  assert repr(key_range) == expected


def test_copy_memory() -> None:
  """A shallow copy shares a memory KvStore's data; a deep copy does not."""
  original = ts.KvStore.open({'driver': 'memory'}).result()

  shallow = copy.copy(original)
  assert shallow is not original

  deep = copy.deepcopy(original)

  original['abc'] = b'def'
  for store in (original, shallow):
    assert store['abc'] == b'def'
  # The write made through the original must not be visible in the deep copy.
  with pytest.raises(KeyError):
    _ = deep['abc']


def test_copy_range_to_ocdbt_memory() -> None:
  """Copies three in-memory OCDBT children into one parent OCDBT store.

  Each child lives under 'memory://host_<key>/' and holds a single key; the
  parent's base is 'memory://'. After copying every child, the parent lists
  all three keys.
  """
  ctx = ts.Context()
  keys = ('a', 'b', 'c')

  def open_child(key):
    return ts.KvStore.open(
        {'driver': 'ocdbt', 'base': f'memory://host_{key}/'}, context=ctx
    ).result()

  # First pass: populate each child with one key named after its host.
  for key in keys:
    child = open_child(key)
    child[key] = f'value_{key}'

  parent = ts.KvStore.open(
      {'driver': 'ocdbt', 'base': 'memory://'}, context=ctx
  ).result()

  # Second pass: reopen each child and copy its contents into the parent.
  for key in keys:
    child = open_child(key)
    child.experimental_copy_range_to(parent).result()

  assert parent.list().result() == [b'a', b'b', b'c']


def test_copy_range_to_ocdbt_file() -> None:
  """Copies a file-backed OCDBT store into an OCDBT store at its parent dir.

  The child's base is '<tmp>/child' and the parent's base is '<tmp>'. After
  the copy, the parent lists the child's keys.
  """
  ctx = ts.Context()
  with tempfile.TemporaryDirectory() as tmp_dir:
    root_url = pathlib.Path(tmp_dir).resolve().as_uri()

    child = ts.KvStore.open(
        {'driver': 'ocdbt', 'base': f'{root_url}/child'}, context=ctx
    ).result()
    for key in ('a', 'b', 'c'):
      child[key] = f'value_{key}'

    parent = ts.KvStore.open(
        {'driver': 'ocdbt', 'base': root_url}, context=ctx
    ).result()
    child.experimental_copy_range_to(parent).result()

    assert parent.list().result() == [b'a', b'b', b'c']


def test_copy_range_to_memory_fails() -> None:
  """experimental_copy_range_to between plain memory stores is unimplemented."""
  ctx = ts.Context()
  source = ts.KvStore.open('memory://child/', context=ctx).result()
  for key in ('a', 'b', 'c'):
    source[key] = f'value_{key}'
  destination = ts.KvStore.open('memory://', context=ctx).result()
  with pytest.raises(NotImplementedError):
    source.experimental_copy_range_to(destination).result()


def test_copy_range_to_file_fails() -> None:
  """experimental_copy_range_to between plain file stores is unimplemented."""
  ctx = ts.Context()
  with tempfile.TemporaryDirectory() as tmp_dir:
    root_url = pathlib.Path(tmp_dir).resolve().as_uri()
    source = ts.KvStore.open(f'{root_url}/child/', context=ctx).result()
    for key in ('a', 'b', 'c'):
      source[key] = f'value_{key}'
    destination = ts.KvStore.open(root_url, context=ctx).result()
    with pytest.raises(NotImplementedError):
      source.experimental_copy_range_to(destination).result()


def test_copy_range_to_ocdbt_memory_bad_path() -> None:
  """Copying into an OCDBT parent at 'memory://c' raises NotImplementedError.

  Note that 'memory://c' is a string prefix of the child's base
  'memory://child/' but is not a directory containing it.
  """
  ctx = ts.Context()

  def open_ocdbt(base):
    return ts.KvStore.open(
        {'driver': 'ocdbt', 'base': base}, context=ctx
    ).result()

  child = open_ocdbt('memory://child/')
  for key in ('a', 'b', 'c'):
    child[key] = f'value_{key}'
  parent = open_ocdbt('memory://c')
  with pytest.raises(NotImplementedError):
    child.experimental_copy_range_to(parent).result()


def _run_threads(
    stop: threading.Event,
    read_props: Callable[[], None],
    update_props: Callable[[], None],
) -> None:
  threads = []
  for _ in range(4):
    threads.append(threading.Thread(target=read_props))
    threads.append(threading.Thread(target=update_props))

  for t in threads:
    t.start()

  time.sleep(0.3)
  stop.set()

  for t in threads:
    t.join()


def test_kvstore_concurrent() -> None:
  """Validates that concurrent updates and reads to a KvStore do not crash.

  Reader threads repeatedly query properties and derived values of a shared
  memory KvStore. At the same time, updater threads keep reassigning its
  `path` and `transaction`.
  """
  store = ts.KvStore.open('memory://').result()
  stop = threading.Event()

  def read_props() -> None:
    probes = (
        lambda: store.path,
        lambda: store.url,
        lambda: store.base,
        lambda: store.transaction,
        lambda: store.spec(),
        lambda: store.list().result(),
        lambda: store == ts.KvStore.open('memory://').result(),
        lambda: f'{store}',
        lambda: repr(store),
        lambda: store / 'foo',
        lambda: store + 'bar',
    )
    while not stop.is_set():
      for probe in probes:
        probe()

  def update_props() -> None:
    time.sleep(0.01)
    txn = ts.Transaction()
    # Cycled in order: path='abc', path='', transaction=txn, transaction=None.
    assignments = (
        ('path', 'abc'),
        ('path', ''),
        ('transaction', txn),
        ('transaction', None),
    )
    step = 0
    while not stop.is_set():
      attr, value = assignments[step % len(assignments)]
      setattr(store, attr, value)
      step += 1

  _run_threads(stop, read_props, update_props)


def test_kvstore_spec_concurrent() -> None:
  """Validates that concurrent reads and path updates of a Spec do not crash."""
  spec = ts.KvStore.Spec('memory://')
  stop = threading.Event()

  def read_props() -> None:
    probes = (
        lambda: spec.path,
        lambda: spec.url,
        lambda: spec.base,
        lambda: spec == ts.KvStore.Spec('memory://'),
        lambda: spec.to_json(),
        lambda: f'{spec}',
        lambda: repr(spec),
    )
    while not stop.is_set():
      for probe in probes:
        probe()

  def update_props() -> None:
    time.sleep(0.01)
    # Alternate between the two paths on every iteration.
    use_first = True
    while not stop.is_set():
      spec.path = 'abc/' if use_first else 'def/'
      use_first = not use_first

  _run_threads(stop, read_props, update_props)


def test_kvstore_keyrange_concurrent() -> None:
  """Tests concurrent reads and writes of KvStore.KeyRange bounds."""
  key_range = ts.KvStore.KeyRange('a', 'z')
  stop = threading.Event()

  def read_props() -> None:
    probes = (
        lambda: key_range.inclusive_min,
        lambda: key_range.exclusive_max,
        lambda: key_range.empty,
        lambda: key_range == ts.KvStore.KeyRange(b'a', b'z'),
        lambda: f'{key_range}',
        lambda: repr(key_range),
    )
    while not stop.is_set():
      for probe in probes:
        probe()

  def update_props() -> None:
    time.sleep(0.01)
    # Mix bytes and str values; all four assignments run on every iteration.
    assignments = (
        ('inclusive_min', b'b'),
        ('exclusive_max', 'y'),
        ('inclusive_min', 'a'),
        ('exclusive_max', b'z'),
    )
    while not stop.is_set():
      for attr, value in assignments:
        setattr(key_range, attr, value)

  _run_threads(stop, read_props, update_props)


def test_kvstore_timestampedstorategeneration_concurrent() -> None:
  """Tests concurrent access to TimestampedStorageGeneration properties."""
  stamp = ts.KvStore.TimestampedStorageGeneration()
  stop = threading.Event()

  def read_props() -> None:
    probes = (
        lambda: stamp.generation,
        lambda: stamp.time,
        lambda: stamp == ts.KvStore.TimestampedStorageGeneration(),
        lambda: f'{stamp}',
        lambda: repr(stamp),
    )
    while not stop.is_set():
      for probe in probes:
        probe()

  def update_props() -> None:
    time.sleep(0.01)
    # All four assignments run, in this order, on every iteration.
    assignments = (
        ('generation', 'gen'),
        ('time', 1.0),
        ('generation', b''),
        ('time', -float('inf')),
    )
    while not stop.is_set():
      for attr, value in assignments:
        setattr(stamp, attr, value)

  _run_threads(stop, read_props, update_props)


def test_kvstore_readresult_concurrent() -> None:
  """Tests concurrent reads and writes of KvStore.ReadResult fields."""
  stamp = ts.KvStore.TimestampedStorageGeneration(b'gen', 1.0)
  result = ts.KvStore.ReadResult('value', b'foo', stamp)
  stop = threading.Event()

  def read_props() -> None:
    probes = (
        lambda: result == ts.KvStore.ReadResult(),
        lambda: result.state,
        lambda: result.value,
        lambda: result.stamp,
        lambda: result == ts.KvStore.ReadResult('value', 'bar', stamp),
        lambda: f'{result}',
        lambda: repr(result),
    )
    while not stop.is_set():
      for probe in probes:
        probe()

  def set_present() -> None:
    # A fresh stamp object is constructed on every call.
    result.state = 'value'
    result.value = b'value'
    result.stamp = ts.KvStore.TimestampedStorageGeneration(b'gen', 1.0)

  def set_missing() -> None:
    result.state = 'missing'
    result.value = b''
    result.stamp = ts.KvStore.TimestampedStorageGeneration()

  def update_props() -> None:
    time.sleep(0.01)
    while not stop.is_set():
      set_present()
      set_missing()

  _run_threads(stop, read_props, update_props)
